{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 第六章：卷积神经网络\n",
    "湖北理工学院《机器学习》课程资料\n",
    "\n",
    "作者：李辉楚吴\n",
    "\n",
    "笔记内容概述: 迁移学习、ResNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Label and tick-label sizes\n",
    "label_size, ticklabel_size = 18, 14\n",
    "\n",
    "# ToTensor only converts the PIL images to tensors scaled to [0, 1];\n",
    "# no mean/std normalisation is applied here.\n",
    "transform = transforms.Compose([transforms.ToTensor()])\n",
    "\n",
    "# Both CIFAR-10 splits are read from a local copy under ./Data (download=False)\n",
    "cifar_kwargs = dict(root='./Data', download=False, transform=transform)\n",
    "\n",
    "# Test split of CIFAR-10\n",
    "testset = torchvision.datasets.CIFAR10(train=False, **cifar_kwargs)\n",
    "print(f\"Test set size: {len(testset)}\")\n",
    "\n",
    "# Training split of CIFAR-10\n",
    "trainset = torchvision.datasets.CIFAR10(train=True, **cifar_kwargs)\n",
    "print(f\"Training set size: {len(trainset)}\")\n",
    "\n",
    "# Fraction of each class assigned to the training subset (trX);\n",
    "# the remainder goes to the cross-validation subset (cvX)\n",
    "tr_cv_rate = 0.8\n",
    "\n",
    "# One list of sample indices per class. Labels are read from `trainset.targets`\n",
    "# so that no image has to be decoded and transformed just to look at its label.\n",
    "n_classes = len(trainset.classes)\n",
    "class_indices = [[] for _ in range(n_classes)]\n",
    "for idx, label in enumerate(trainset.targets):\n",
    "    class_indices[label].append(idx)\n",
    "\n",
    "# Size both subsets from the smallest class so every class contributes equally\n",
    "smallest_class = min(len(indices) for indices in class_indices)\n",
    "train_size_per_class = int(tr_cv_rate * smallest_class)\n",
    "val_size_per_class = smallest_class - train_size_per_class\n",
    "\n",
    "# Build the balanced index lists: the first slice of each class trains,\n",
    "# the following slice validates\n",
    "train_indices = []\n",
    "val_indices = []\n",
    "for indices in class_indices:\n",
    "    train_indices.extend(indices[:train_size_per_class])\n",
    "    val_indices.extend(indices[train_size_per_class:train_size_per_class + val_size_per_class])\n",
    "\n",
    "# Wrap the index lists as Subset views over the training set\n",
    "from torch.utils.data import Subset\n",
    "trX = Subset(trainset, train_indices)\n",
    "cvX = Subset(trainset, val_indices)\n",
    "\n",
    "print(f\"Number of training samples: {len(trX)}\")\n",
    "print(f\"Number of cross-validation samples: {len(cvX)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "构建DataLoaders，准备训练模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 64\n",
    "\n",
    "def one_hot_collate(batch, num_classes=10):\n",
    "    \"\"\"Collate (image, label) samples into a batch with one-hot targets.\n",
    "\n",
    "    Args:\n",
    "        batch: sequence of (image_tensor, int_label) samples; the image tensors\n",
    "            must share one shape so they can be stacked.\n",
    "        num_classes: width of each one-hot vector. Defaults to 10, the number\n",
    "            of CIFAR-10 classes used in this notebook.\n",
    "\n",
    "    Returns:\n",
    "        (data, one_hot_labels): the images stacked along a new leading batch\n",
    "        dimension, and a tensor of shape (len(batch), num_classes) holding a\n",
    "        single 1 per row at the label index, e.g. label 1 -> [0, 1, 0, ...].\n",
    "    \"\"\"\n",
    "    data = torch.stack([item[0] for item in batch])\n",
    "    labels = torch.tensor([item[1] for item in batch])\n",
    "    one_hot_labels = torch.zeros(labels.size(0), num_classes)\n",
    "    # scatter_ writes a 1 into column labels[i] of row i\n",
    "    one_hot_labels.scatter_(1, labels.unsqueeze(1), 1)\n",
    "    return data, one_hot_labels\n",
    "\n",
    "# Settings shared by all three loaders; only the training loader shuffles\n",
    "loader_kwargs = dict(batch_size=batch_size, num_workers=0, collate_fn=one_hot_collate)\n",
    "trLoader = torch.utils.data.DataLoader(trX, shuffle=True, **loader_kwargs)\n",
    "cvLoader = torch.utils.data.DataLoader(cvX, shuffle=False, **loader_kwargs)\n",
    "teLoader = torch.utils.data.DataLoader(testset, shuffle=False, **loader_kwargs)\n",
    "\n",
    "# Pull one batch from the training loader for inspection\n",
    "dataiter = iter(trLoader)\n",
    "data, labels = next(dataiter)\n",
    "\n",
    "# Dim 0 of a single sample is its channel count\n",
    "image_channels = data[0].shape[0]\n",
    "print(f'image_channels is {image_channels}')\n",
    "# One-hot target of the first sample\n",
    "print(labels[0])\n",
    "\n",
    "# CIFAR-10 class names, indexed by label id\n",
    "label_text = ['airplane', 'automobile', 'bird', 'cat', 'deer',\n",
    "              'dog', 'frog', 'horse', 'ship', 'truck']\n",
    "\n",
    "# Show the first image of the batch, titled with its class name\n",
    "fig, ax = plt.subplots(figsize=(6, 6))\n",
    "# imshow expects channels last: rearrange (3,32,32) -> (32,32,3)\n",
    "ax.imshow(data[0].permute(1, 2, 0).numpy())\n",
    "ax.set_title(f'Label: {label_text[labels[0].argmax().item()]}')\n",
    "ax.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 迁移ResNet微调FNN层"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "\n",
    "# 1. Load ResNet-18 with ImageNet weights. The `weights=` argument replaces the\n",
    "#    deprecated `pretrained=True` flag (torchvision >= 0.13).\n",
    "model = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1)\n",
    "print(model)\n",
    "# 2. Replace the stem convolution. The inputs are RGB, so the channel count stays\n",
    "#    at 3; what changes is the kernel/stride (3x3, stride 1) so that the small\n",
    "#    input images are not downsampled right away. The new layer is freshly\n",
    "#    initialised, i.e. the pretrained conv1 weights are not reused.\n",
    "model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)\n",
    "\n",
    "# 3. Remove the first max-pooling layer so spatial detail is not lost too early\n",
    "model.maxpool = nn.Identity()\n",
    "\n",
    "# 4. Replace the output layer with one unit per target class\n",
    "model.fc = nn.Linear(model.fc.in_features, 10)  # 10 CIFAR-10 classes\n",
    "\n",
    "# Print the modified architecture\n",
    "print(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "微调ResNet18"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Loss function; the loaders yield one-hot targets built by one_hot_collate\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "# No layer is frozen here, so every parameter is optimised. To update only\n",
    "# unfrozen parameters, pass filter(lambda p: p.requires_grad, model.parameters()).\n",
    "optimizer = torch.optim.Adam(model.parameters())\n",
    "\n",
    "# Training configuration and per-epoch loss history\n",
    "num_epochs = 5\n",
    "log_every = 100  # print one progress line every `log_every` batches\n",
    "train_losses = []\n",
    "cv_losses = []\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    model.train()\n",
    "    running_loss = 0.0\n",
    "    for step, (images, labels) in enumerate(trLoader, start=1):\n",
    "        optimizer.zero_grad()\n",
    "        outputs = model(images)\n",
    "        loss = criterion(outputs, labels)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        running_loss += loss.item()\n",
    "\n",
    "        # Periodic progress instead of one line per batch\n",
    "        if step % log_every == 0:\n",
    "            print(f'Epoch {epoch+1}/{num_epochs}, Batch {step}/{len(trLoader)}, Training Loss: {loss.item():.4f}')\n",
    "\n",
    "    # Record the mean over all batches of the epoch (not just the last batch)\n",
    "    # so the value is comparable with the averaged cross-validation loss below\n",
    "    train_losses.append(running_loss / len(trLoader))\n",
    "\n",
    "    # Cross-validation loss, averaged over batches\n",
    "    model.eval()\n",
    "    cv_loss = 0.0\n",
    "    with torch.no_grad():\n",
    "        for images, labels in cvLoader:\n",
    "            outputs = model(images)\n",
    "            loss = criterion(outputs, labels)\n",
    "            cv_loss += loss.item()\n",
    "    cv_losses.append(cv_loss / len(cvLoader))\n",
    "\n",
    "    print(f'Epoch {epoch+1}/{num_epochs}, Training Loss: {train_losses[-1]:.4f}, Cross-Validation Loss: {cv_losses[-1]:.4f}')\n",
    "\n",
    "# Save the fine-tuned weights (file name kept as-is although the data is CIFAR-10)\n",
    "torch.save(model.state_dict(), 'mnist_resnet18_finetuned.pth')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "计算识别精度，展示学习曲线"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def count_correct(net, loader):\n",
    "    \"\"\"Return (correct, total) prediction counts of `net` over `loader`.\n",
    "\n",
    "    `loader` must yield (images, one_hot_labels) batches; the true class of a\n",
    "    sample is the arg-max of its one-hot row. The network is switched to eval\n",
    "    mode and gradients are disabled while counting.\n",
    "    \"\"\"\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    net.eval()\n",
    "    with torch.no_grad():\n",
    "        for images, labels in loader:\n",
    "            outputs = net(images)\n",
    "            _, predicted = torch.max(outputs, 1)\n",
    "            _, true_labels = torch.max(labels, 1)\n",
    "            total += labels.size(0)\n",
    "            correct += (predicted == true_labels).sum().item()\n",
    "    return correct, total\n",
    "\n",
    "# Training-set accuracy (percent)\n",
    "tr_correct, tr_total = count_correct(model, trLoader)\n",
    "tr_accuracy = 100 * tr_correct / tr_total\n",
    "\n",
    "# Cross-validation accuracy (percent)\n",
    "cv_correct, cv_total = count_correct(model, cvLoader)\n",
    "cv_accuracy = 100 * cv_correct / cv_total\n",
    "\n",
    "print(f'Accuracy on training set: {tr_accuracy:.2f}%')\n",
    "print(f'Accuracy on cross-validation set: {cv_accuracy:.2f}%')\n",
    "\n",
    "# Learning curves: one point per epoch for each loss history\n",
    "plt.figure(figsize=(10, 5))\n",
    "plt.plot(range(1, num_epochs+1), train_losses, label='Training Loss')\n",
    "plt.plot(range(1, num_epochs+1), cv_losses, label='Cross-Validation Loss')\n",
    "plt.xlabel('Epoch')\n",
    "plt.ylabel('Loss')\n",
    "plt.title('Training and Cross-Validation Loss')\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "计算测试精度"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Accuracy on the test loader; labels are one-hot, so arg-max recovers the class id\n",
    "model.eval()\n",
    "test_correct = 0\n",
    "test_total = 0\n",
    "with torch.no_grad():\n",
    "    for images, labels in teLoader:\n",
    "        outputs = model(images)\n",
    "        predicted = outputs.argmax(dim=1)\n",
    "        true_labels = labels.argmax(dim=1)\n",
    "        test_total += labels.size(0)\n",
    "        test_correct += (predicted == true_labels).sum().item()\n",
    "test_accuracy = 100 * test_correct / test_total\n",
    "print(f'Accuracy on test set: {test_accuracy:.2f}%')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "----------------------------------------------------------\n",
    "## 实验11 基于迁移学习的图像识别\n",
    "一、\t实验目的\n",
    "1.\t了解迁移学习的概念\n",
    "2.\t能够使用迁移学习修改ResNet网络\n",
    "3.\t针对新问题微调ResNet网络\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import datasets, transforms, models\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Select the compute device: the GPU when CUDA is available, otherwise the CPU\n",
    "device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
    "\n",
    "# Data preprocessing\n",
    "def preprocess_data(data_dir=\"./data\", img_size=(224, 224),\n",
    "                    mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):\n",
    "    \"\"\"Build the train and test ImageFolder datasets with a shared transform.\n",
    "\n",
    "    Args:\n",
    "        data_dir: directory containing `train/` and `test/` sub-folders, each\n",
    "            laid out as one sub-folder per class (ImageFolder convention).\n",
    "        img_size: (height, width) every image is resized to.\n",
    "        mean: per-channel mean passed to Normalize. Defaults to the ImageNet\n",
    "            statistics, which is what ImageNet-pretrained torchvision models\n",
    "            were trained with.\n",
    "        std: per-channel standard deviation passed to Normalize (ImageNet\n",
    "            statistics by default).\n",
    "\n",
    "    Returns:\n",
    "        (train_data, test_data): the two ImageFolder datasets.\n",
    "    \"\"\"\n",
    "    transform = transforms.Compose([\n",
    "        transforms.Resize(img_size),\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize(mean, std)\n",
    "    ])\n",
    "    train_data = datasets.ImageFolder(root=f\"{data_dir}/train\", transform=transform)\n",
    "    test_data = datasets.ImageFolder(root=f\"{data_dir}/test\", transform=transform)\n",
    "    return train_data, test_data\n",
    "\n",
    "# Adapt a pretrained ResNet for fine-tuning\n",
    "def modify_resnet18(num_classes):\n",
    "    \"\"\"Return an ImageNet-pretrained ResNet-18 adapted to `num_classes` outputs.\n",
    "\n",
    "    The first convolution layer is frozen, the final fully connected layer is\n",
    "    replaced by a new Linear layer with `num_classes` outputs, and the model\n",
    "    is moved to the module-level `device`. The architecture is printed before\n",
    "    and after the modification.\n",
    "\n",
    "    Args:\n",
    "        num_classes: number of output classes of the new classifier head.\n",
    "\n",
    "    Returns:\n",
    "        The modified model, located on `device`.\n",
    "    \"\"\"\n",
    "    # `weights=` replaces the deprecated `pretrained=True` flag (torchvision >= 0.13)\n",
    "    model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)\n",
    "    print(\"原始ResNet18模型：\")\n",
    "    print(model)\n",
    "    # Freeze only the first (stem) convolution. Matching the substring 'conv1'\n",
    "    # against parameter names would also hit layerX.Y.conv1 inside every\n",
    "    # residual block, so the stem module is addressed directly instead.\n",
    "    for param in model.conv1.parameters():\n",
    "        param.requires_grad = False\n",
    "\n",
    "    # Replace the output layer\n",
    "    num_ftrs = model.fc.in_features\n",
    "    model.fc = nn.Linear(num_ftrs, num_classes)\n",
    "    model = model.to(device)\n",
    "    print(\"修改后的ResNet18模型：\")\n",
    "    print(model)\n",
    "    return model\n",
    "\n",
    "# Model training\n",
    "def _mean_batch_loss(model, loader, criterion, device, optimizer=None):\n",
    "    \"\"\"Run one pass over `loader` and return the mean of the per-batch losses.\n",
    "\n",
    "    With an optimizer the model is put in train mode and updated after every\n",
    "    batch; without one it is put in eval mode and gradients are disabled.\n",
    "    \"\"\"\n",
    "    training = optimizer is not None\n",
    "    model.train(training)\n",
    "    total = 0.0\n",
    "    with torch.set_grad_enabled(training):\n",
    "        for inputs, labels in loader:\n",
    "            inputs, labels = inputs.to(device), labels.to(device)\n",
    "            if training:\n",
    "                optimizer.zero_grad()\n",
    "            loss = criterion(model(inputs), labels)\n",
    "            if training:\n",
    "                loss.backward()\n",
    "                optimizer.step()\n",
    "            total += loss.item()\n",
    "    return total / len(loader)\n",
    "\n",
    "\n",
    "def train_model(model, train_loader, val_loader, epochs, learning_rate):\n",
    "    \"\"\"Train `model` with Adam and cross-entropy loss, tracking both losses.\n",
    "\n",
    "    Args:\n",
    "        model: network to optimise; batches are moved to the device its\n",
    "            parameters already live on.\n",
    "        train_loader: iterable of (inputs, labels) batches used for the updates.\n",
    "        val_loader: iterable of (inputs, labels) batches used for evaluation only.\n",
    "        epochs: number of passes over `train_loader`.\n",
    "        learning_rate: Adam learning rate.\n",
    "\n",
    "    Returns:\n",
    "        (train_loss_history, val_loss_history): two lists with one mean\n",
    "        per-batch loss per epoch.\n",
    "    \"\"\"\n",
    "    # Follow the model's own device so inputs can never end up on a different one\n",
    "    model_device = next(model.parameters()).device\n",
    "    # Frozen parameters (requires_grad=False) are left out of the optimizer\n",
    "    trainable_params = [p for p in model.parameters() if p.requires_grad]\n",
    "    optimizer = optim.Adam(trainable_params, lr=learning_rate)\n",
    "    criterion = nn.CrossEntropyLoss()\n",
    "    train_loss_history, val_loss_history = [], []\n",
    "\n",
    "    for epoch in range(epochs):\n",
    "        train_loss = _mean_batch_loss(model, train_loader, criterion, model_device, optimizer)\n",
    "        train_loss_history.append(train_loss)\n",
    "\n",
    "        val_loss = _mean_batch_loss(model, val_loader, criterion, model_device)\n",
    "        val_loss_history.append(val_loss)\n",
    "\n",
    "        print(f\"Epoch {epoch + 1}/{epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}\")\n",
    "\n",
    "    return train_loss_history, val_loss_history\n",
    "\n",
    "# Model evaluation\n",
    "def test_model(model, test_loader):\n",
    "    \"\"\"Print and return the accuracy of `model` over `test_loader`.\n",
    "\n",
    "    Batches are moved to the module-level `device`. The predicted class is the\n",
    "    arg-max of the model output and is compared with the integer labels.\n",
    "\n",
    "    Returns:\n",
    "        The fraction of correctly classified samples (correct / total).\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    hits = 0\n",
    "    seen = 0\n",
    "    with torch.no_grad():\n",
    "        for inputs, labels in test_loader:\n",
    "            inputs = inputs.to(device)\n",
    "            labels = labels.to(device)\n",
    "            predicted = model(inputs).argmax(dim=1)\n",
    "            seen += labels.size(0)\n",
    "            hits += (predicted == labels).sum().item()\n",
    "    accuracy = hits / seen\n",
    "    print(f\"Test Accuracy: {accuracy:.4f}\")\n",
    "    return accuracy\n",
    "\n",
    "# Learning-curve visualisation\n",
    "def plot_learning_curve(train_loss, val_loss):\n",
    "    \"\"\"Plot the training and validation loss histories on one set of axes.\n",
    "\n",
    "    Args:\n",
    "        train_loss: sequence of training losses, one value per epoch.\n",
    "        val_loss: sequence of validation losses, one value per epoch.\n",
    "    \"\"\"\n",
    "    curves = ((train_loss, 'Train Loss'), (val_loss, 'Validation Loss'))\n",
    "    for series, name in curves:\n",
    "        plt.plot(series, label=name)\n",
    "    plt.xlabel('Epochs')\n",
    "    plt.ylabel('Loss')\n",
    "    plt.legend()\n",
    "    plt.title('Learning Curve')\n",
    "    plt.show()\n",
    "\n",
    "# Entry point\n",
    "if __name__ == \"__main__\":\n",
    "    from torch.utils.data import random_split\n",
    "\n",
    "    # Image data location and preprocessing\n",
    "    data_dir = \"./image_data\"  # replace with the path to your dataset\n",
    "    img_size = (224, 224)\n",
    "    train_data, test_data = preprocess_data(data_dir, img_size)\n",
    "\n",
    "    # Hold out 20% of the training images for validation so that the test set\n",
    "    # is used only for the final accuracy. The generator is seeded to make\n",
    "    # the split repeatable across runs.\n",
    "    val_size = max(1, int(0.2 * len(train_data)))\n",
    "    train_subset, val_subset = random_split(\n",
    "        train_data,\n",
    "        [len(train_data) - val_size, val_size],\n",
    "        generator=torch.Generator().manual_seed(0),\n",
    "    )\n",
    "\n",
    "    train_loader = DataLoader(train_subset, batch_size=32, shuffle=True)\n",
    "    val_loader = DataLoader(val_subset, batch_size=32, shuffle=False)\n",
    "    test_loader = DataLoader(test_data, batch_size=32, shuffle=False)\n",
    "\n",
    "    # Model initialisation: one output per class folder found in the training data\n",
    "    num_classes = len(train_data.classes)\n",
    "    model = modify_resnet18(num_classes=num_classes)\n",
    "\n",
    "    # Train the model\n",
    "    epochs = 10\n",
    "    learning_rate = 0.001\n",
    "    train_loss, val_loss = train_model(model, train_loader, val_loader, epochs, learning_rate)\n",
    "\n",
    "    # Plot the learning curves\n",
    "    plot_learning_curve(train_loss, val_loss)\n",
    "\n",
    "    # Evaluate on the untouched test set\n",
    "    test_model(model, test_loader)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "machinelearning",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
